/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This proto file defines messages that represent an HLO module. It is a
// full-fidelity serialization of the C++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
// corresponding C++ classes, e.g., HloModule, HloComputation, and
// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
// Unlike most protos, you can't safely change the names of fields, even if you
// keep the numeric ids the same. This is because we sometimes serialize these
// protos as JSON, which includes the field names in the serialization.

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

// Serialization of HloInstruction.
// Next ID: 64
message HloInstructionProto {
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Field number 39 is the only number below "Next ID" that is neither used
  // by a field nor reserved. Reserve it so a future field cannot silently
  // reinterpret data written under an older meaning of the number.
  reserved 39;

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor of
  // the input feature dimension and output feature dimension. If not specified,
  // it will use a default value of 1.
  int64 feature_group_count = 50;

  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair.
  // Only present for kSend or kRecv.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., a global symbol) to call, only present
  // for kCustomCall.
  string custom_call_target = 28;

  // Opaque string, only present for kCustomCall.
  string custom_call_opaque = 53;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc.).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute Host.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  string backend_config = 43;

  // Cross replica op fields.
  repeated ReplicaGroup replica_groups = 49;
  int64 all_reduce_id = 45;
  string all_reduce_barrier = 46;

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regards to replicas.
  xla.ParameterReplication parameter_replication = 61;
}

// Serialization of HloComputation.
message HloComputationProto {
  reserved 3;
  reserved "root_name";

  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  // (No blank line between this comment and the field: a blank line would
  // make protoc treat the comment as detached rather than as field docs.)
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root instruction of the computation; an instruction id in
  // the sense of HloInstructionProto.id.
  int64 root_id = 6;
}

// Serialized form of an HLO schedule: for every non-fusion computation in the
// module, a total order over that computation's instructions.
message HloScheduleProto {
  // The instructions of one computation, in schedule order, referenced by
  // instruction id.
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Keyed by computation id; each value is that computation's sequence.
  map<int64, InstructionSequence> sequences = 1;
}

// Alias information between parameters of the entry computation and
// elements of the root instruction's output.
message HloInputOutputAliasProto {
  enum Kind {
    // Define an UNDEFINED_ALIAS equal to zero to get around the default-0
    // proto3 behavior and missing has_*() APIs.
    UNDEFINED_ALIAS = 0;
    // An alias set up by the user as a must-alias. A user setting USER_ALIAS
    // expects the designated output to be dropped over the given input
    // parameter number+index.
    USER_ALIAS = 1;
    // An alias set up by the compiler as part of its optimizations.
    SYSTEM_ALIAS = 2;
  }

  // The following proto describes a pair of an aliased input
  // (described by parameter number and a ShapeIndex of the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //  output_shape_index={1},
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be set up.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}

message DynamicParameterBindingProto {
  // A list of bindings which indicates that the `target_param_dim_num` in
  // the subshape `target_param_index` of parameter `target_param_num`
  // is a dynamic dimension and its real dynamic size is represented
  // by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (param index 1) is a dynamic shape whose input has
  // an upper bound of 10 and real size is determined at runtime. 'a'
  // represents the real size of b's first dimension.
  //
  // In this case, the fields are set in the following way ('a' holds the
  // size, so it is the dynamic param; 'b' has the dynamic dimension, so it is
  // the target param):
  // dynamic_param_num = 0
  // dynamic_param_index = {}
  // target_param_num = 1
  // target_param_index = {}
  // target_param_dim_num = 0
  message Binding {
    // Parameter number holding the runtime size.
    int64 dynamic_param_num = 1;
    // Index into that parameter's shape locating the size value.
    repeated int64 dynamic_param_index = 2;
    // Parameter number whose shape has the dynamic dimension.
    int64 target_param_num = 3;
    // Index into that parameter's shape locating the dynamic subshape.
    repeated int64 target_param_index = 4;
    // The dimension of that subshape which is dynamic.
    int64 target_param_dim_num = 5;
  }

  repeated Binding entries = 1;
}

// Serialized form of an HloModule.
message HloModuleProto {
  // Name of the module.
  string name = 1;
  // The entry computation, recorded both by name and by computation id.
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // Every computation in the module, ordered so that each callee comes
  // before any of its callers (a valid dependency order).
  repeated HloComputationProto computations = 3;

  // Host program shape of the entry computation, including layout.
  xla.ProgramShapeProto host_program_shape = 4;

  // Identifier of this module.
  int64 id = 5;

  // Schedule of the module.
  HloScheduleProto schedule = 7;

  // Aliasing between the module's inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  // Bindings between dynamic parameter dimensions and the parameters that
  // carry their runtime sizes.
  DynamicParameterBindingProto dynamic_parameter_binding = 9;
}

// Serialized form of a LogicalBuffer.
message LogicalBufferProto {
  // A point at which a buffer is needed, uniquely identified by an
  // instruction together with a shape index into it.
  message Location {
    // There is deliberately no module_name: every LogicalBuffer belongs to a
    // single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    repeated int64 shape_index = 3;
  }

  int64 id = 1;
  int64 size = 2;

  // Where this buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}

// Serialization of BufferAllocation.
message BufferAllocationProto {
  // Field number 4 is unused even though numbers 1-3 and 5-12 are all taken;
  // presumably it belonged to a removed field. Reserve it so it can never be
  // reassigned with a different meaning.
  reserved 4;

  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  repeated Assigned assigned = 9;
}

// Record of a single HeapSimulator run, expressed as a list of events.
message HeapSimulatorTrace {
  // One action taken by the heap simulator.
  message Event {
    // NOTE: ALLOC is the zero value, so an Event whose kind was never set
    // reads back as ALLOC under proto3 default semantics.
    enum Kind {
      // The buffer received a newly allocated memory region.
      ALLOC = 0;
      // The buffer's memory region was released.
      FREE = 1;

      // The buffer shares the memory region of another (canonical) buffer.
      // This behaves like ALLOC except that no new region is allocated; the
      // canonical buffer's region is re-used directly. Several buffers may
      // share one canonical buffer, whose lifetime then spans the union of
      // all of their lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // Id of the LogicalBuffer this event concerns.
    int64 buffer_id = 2;

    // The HloInstruction being processed when the event occurred, named by
    // its computation and instruction. For example, buffers defined by an
    // instruction A are allocated while A is being processed.
    string computation_name = 3;
    string instruction_name = 4;

    // Id of the canonical LogicalBuffer being shared with; set only for
    // SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  repeated Event events = 1;
  // NOTE(review): presumably true when the run simulated the whole module
  // rather than a single computation; confirm.
  bool whole_module_simulation = 2;
}

// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  // Name of the group.
  string name = 1;
  // The modules that make up the group.
  repeated HloModuleProto hlo_modules = 2;
}

// Serialized form of a BufferAssignment.
message BufferAssignmentProto {
  // Pairs a source LogicalBuffer, referenced by id, with a buffer location
  // that aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  // Logical buffers covered by the assignment.
  repeated LogicalBufferProto logical_buffers = 1;
  // Aliases of those logical buffers.
  repeated BufferAlias buffer_aliases = 2;
  // Allocations that the logical buffers were assigned into.
  repeated BufferAllocationProto buffer_allocations = 3;
  // Heap simulator traces recorded for the assignment.
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}

// Top-level container bundling an HLO module with its buffer assignment.
message HloProto {
  // Retired: neither the number 2 nor the name "hlo_ordering" may be reused.
  reserved 2;
  reserved "hlo_ordering";

  // The serialized module.
  HloModuleProto hlo_module = 1;
  // The buffer assignment computed for that module.
  BufferAssignmentProto buffer_assignment = 3;
}

// An HloProto packaged with the arguments passed to it, its result, and the
// platform it ran on. Used for purposes such as analysis, replay, and
// storage in files.
message HloSnapshot {
  // The HLO graph.
  HloProto hlo = 1;

  // Arguments that were passed to the graph.
  repeated xla.LiteralProto arguments = 2;

  // Result of the graph.
  xla.LiteralProto result = 3;

  // Name of the platform on which the graph was run.
  string execution_platform = 4;
}
